
def modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d):
   """Replace the initial u3d by the profile stored in 'y_u_k_om_uv_395.dat'.

   Column 1 of the file is copied to every i- and k-station, giving an
   array of shape (ni, n_rows, nk). All other fields are passed through
   unchanged. Reads the module-level ni and nk.
   """
   profile=np.loadtxt('y_u_k_om_uv_395.dat')[:,1]
   # same 1D profile at every streamwise (i) and spanwise (k) station in the entire domain
   u3d=np.broadcast_to(profile[None,:,None],(ni,profile.size,nk)).copy()
   return u3d,v3d,w3d,k3d,om3d,eps3d,vis3d

def bc_outlet(convw):
   """Correct the outlet (east) face fluxes so that global mass is conserved.

   The difference between the total inflow through the first i-plane of
   faces, sum(convw[0]), and the total outflow through the last one,
   sum(convw[-1]), is spread over the outlet faces as a uniform velocity
   increment uinc = (flow_in - flow_out)/area_out. Afterwards
   sum(convw[-1]) equals flow_in up to round-off.

   Reads the module-level areaw and, for the diagnostic print only, u3d and
   u3d_face_w. u_bc_east is returned unchanged.

   Parameters
   ----------
   convw : ndarray
       West-face convective fluxes, shape (ni+1, nj, nk); the last i-plane
       is updated in place.

   Returns
   -------
   convw, u_bc_east
   """
   flow_in=np.sum(convw[0,:,:])     # total flux through the inlet (west) plane
   flow_out=np.sum(convw[-1,:,:])   # total flux through the outlet (east) plane
   area_out=np.sum(areaw[-1,:,:])

   uinc=(flow_in-flow_out)/area_out
   ares=areaw[-1,:,:]
   convw[-1,:,:]=convw[-1,:,:]+uinc*ares

   # NOTE(review): hard-coded diagnostic index; requires nj > 40 and nk > 1
   print('u3d,u3d_face_w[-1,40,1]',u3d[-1,40,1],u3d_face_w[-1,40,1])

   flow_out_new=np.sum(convw[-1,:,:])

   print('flow_in',flow_in,'flow_out',flow_out,'area_out',area_out,'flow_out_new',flow_out_new,'uinc:',uinc)

   return convw,u_bc_east

def compute_inlet_fluct():
   """Advance the synthetic inlet fluctuations and impose them at the west boundary.

   On the first time step (itstep == 0) the inlet profile is read from
   'y_u_k_om_uv_395.dat' (columns used as y, U, k, omega, uv). A first
   synthetic field is generated, and the time-filter coefficients are set:
   a_synt = exp(-dt/tturb), with tturb = L_t_synt/uin, and
   b_synt = sqrt(1 - a_synt**2), so that a_synt**2 + b_synt**2 = 1.

   Every call then generates a fresh field and blends it into the stored
   one (new = a_synt*old + b_synt*fresh). The west boundary values are
   u = u_rans + u', v = v', w = w'.

   Side effects: updates the module globals in the ``global`` statement and
   overwrites the inlet plane of the module-level arrays u3d_face_w and
   convw in place.

   Returns
   -------
   u_bc_west, v_bc_west, w_bc_west, u3d_face_w, convw
   """
   from synt_fluct import synt_fluct

   global y_rans,u_rans,k_rans,om_rans,uv_rans,zp,a_synt,b_synt,usynt_inlet,vsynt_inlet,wsynt_inlet,\
          uu_synt,vv_synt,ww_synt,uv_synt
   if itstep == 0:
      y_u_k_om=np.loadtxt('y_u_k_om_uv_395.dat')
      y_rans=y_u_k_om[:,0]
      u_rans=y_u_k_om[:,1]
      # make it 2D: same profile at every spanwise station
      u_rans=np.repeat(u_rans[:,None], repeats=nk, axis=1)
      k_rans=y_u_k_om[:,2]
      om_rans=y_u_k_om[:,3]
      uv_rans=np.abs(y_u_k_om[:,4])

      # z grid
      zp = np.linspace(0, zmax, nk)
      usynt,vsynt,wsynt=synt_fluct(nmodes_synt,itstep,L_t_synt,y_rans,zp,uv_rans,viscos,jmirror_synt)
      # shift usynt to zero mean (easier to converge the p solver)
      usynt=usynt-np.mean(usynt)
      usynt_inlet=usynt
      vsynt_inlet=vsynt
      wsynt_inlet=wsynt
      # bulk inlet velocity = inlet flux / (channel height * span)
      uin=np.sum(convw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
      tturb=L_t_synt/uin
      a_synt=np.exp(-dt[itstep]/tturb)
      b_synt=(1.-a_synt**2)**0.5
      print('a_synt,b_synt',a_synt,b_synt)
      uu_synt=np.zeros(nj)
      vv_synt=np.zeros(nj)
      ww_synt=np.zeros(nj)
      uv_synt=np.zeros(nj)

   usynt,vsynt,wsynt=synt_fluct(nmodes_synt,itstep,L_t_synt,y_rans,zp,uv_rans,viscos,jmirror_synt)
   # shift usynt to zero mean (easier to converge the p solver)
   usynt=usynt-np.mean(usynt)
   # time filter: blend the stored fluctuations with the fresh ones
   usynt_inlet=a_synt*usynt_inlet+b_synt*usynt
   vsynt_inlet=a_synt*vsynt_inlet+b_synt*vsynt
   wsynt_inlet=a_synt*wsynt_inlet+b_synt*wsynt
   u_bc_west=u_rans+usynt_inlet
   v_bc_west=vsynt_inlet
   w_bc_west=wsynt_inlet

   # running sums of the spanwise-averaged stresses.
   # NOTE(review): these use the fresh (unfiltered) usynt/vsynt/wsynt,
   # not the filtered *_inlet values actually imposed -- confirm intent.
   uu_synt=uu_synt+np.mean(usynt**2,axis=1)
   vv_synt=vv_synt+np.mean(vsynt**2,axis=1)
   ww_synt=ww_synt+np.mean(wsynt**2,axis=1)
   uv_synt=uv_synt+np.mean(usynt*vsynt,axis=1)

   if itstep%100 == 0:
      print('uu_synt',uu_synt/(itstep+1))
      print('uv_synt',uv_synt/(itstep+1))

   # update face velocity and convw at inlet
   u3d_face_w[0,:,:]=u_bc_west
   convw[0,:,:]=-u_bc_west*areawx[0,:,:]-v_bc_west*areawy[0,:,:]

   return u_bc_west,v_bc_west,w_bc_west,u3d_face_w,convw

def modify_u(su3d,sp3d):
   """Add the west-inlet convective terms to the u-equation sources.

   In the first i-plane the inlet value u_bc_west enters explicitly through
   su (convw*u_bc_west) and the inlet flux implicitly through sp (-convw).
   The arrays are modified in place and also returned.
   """
   conv_in=convw[0,:,:]
   su3d[0]=su3d[0]+conv_in*u_bc_west
   sp3d[0]=sp3d[0]-conv_in
   if not itstep % 100:
      print('modu: u_bc_west[:,20]', u_bc_west[:,20])
   return su3d,sp3d


def modify_v(su3d,sp3d):
   """Add the west-inlet convective terms to the v-equation sources.

   In the first i-plane the inlet value v_bc_west enters explicitly through
   su (convw*v_bc_west) and the inlet flux implicitly through sp (-convw).
   The arrays are modified in place and also returned.
   """
   conv_in=convw[0,:,:]
   su3d[0]=su3d[0]+conv_in*v_bc_west
   sp3d[0]=sp3d[0]-conv_in

   if not itstep % 100:
      print('modv: v_bc_west[:,20]', v_bc_west[:,20])
   return su3d,sp3d


def modify_w(su3d,sp3d):
   """Add the west-inlet convective terms to the w-equation sources.

   In the first i-plane the inlet value w_bc_west enters explicitly through
   su (convw*w_bc_west) and the inlet flux implicitly through sp (-convw).
   The arrays are modified in place and also returned.
   """
   conv_in=convw[0,:,:]
   su3d[0]=su3d[0]+conv_in*w_bc_west
   sp3d[0]=sp3d[0]-conv_in
   if not itstep % 100:
      print('modw: w_bc_west[:,20]', w_bc_west[:,20])
   return su3d,sp3d


def modify_k(su3d,sp3d):
   """Case hook for extra k-equation sources; this case adds none."""
   unchanged=(su3d,sp3d)
   return unchanged


def modify_eps(su3d,sp3d):
   """Case hook for extra eps-equation sources; this case adds none.

   Unlike the other modify_* hooks it returns three values: the
   module-level eps3d is passed through as the third element.
   """
   result=(su3d,sp3d,eps3d)
   return result

def modify_om(su3d,sp3d):
   """Case hook for extra omega-equation sources; this case adds none."""
   unchanged=(su3d,sp3d)
   return unchanged

def case_setup():
   """Define all numerical, grid and boundary-condition parameters of this case.

   Reads three text files from the working directory:

   * x2d.dat, y2d.dat -- the (ni+1)*(nj+1) grid-node coordinates flattened in
     C order, followed by ni (x2d.dat) resp. nj (y2d.dat);
   * z.dat            -- zmax and nk.

   Boundary-condition types are 'd' (Dirichlet) or 'n' (Neumann); for cyclic
   directions the types are switched to 'n' at the end.

   N.B. every variable set here must be included in the return tuple, and
   the tuple order must match the unpacking in init().

   Returns
   -------
   tuple
       The case parameters, in the order listed in the return statement.
   """
   import numpy as np

   scheme='u'  # first-order upwind ('h': hybrid, 'c': central differencing; see coeff)
#  scheme='h'  # hybrid upwind-central
   scheme_keps='u'  # first-order upwind
#  scheme_keps='h'  # hybrid upwind-central

   pans = False
   keps = False
   kom_des = False
   kom = False
   wale = True
   smag = False

   c_eps_1=1.44
   c_eps_2=1.82
   cmu=0.09
   c_omega_1= 5./9.
   c_omega_2=3./40.
   prand_omega=2.0
   prand_eps=1.3
   prand_k=1.3

   if keps:
      prand_k=1.0

   if pans:
      prand_k=-1.0 # negative: coeff() uses |prand|*fk3d**2 instead
      prand_eps=-1.3 # negative: coeff() uses |prand|*fk3d**2 instead

   if kom or kom_des:
      prand_k=2.0
   if smag:
      cmu=0.1

   restart = False
   save = True
   cyclic_x = False
   cyclic_z = True
   viscos=1/395
#  viscos=1/50
   urfvis=0.5
   maxit=5
   itstep_save=2000  # save every itstep_save timestep
   itstep_average=1 # time average every itstep_average timestep
   ntstep=20000
   itstep_start=10000
#  sormax=1e-4
   sormax=1e-3
#  acrank=0.5
   acrank=1.0  # for pressure gradient
   acrank_conv=0.5  # for convection-diffusion
#  acrank_conv=1.0
#  acrank_conv_kom=1
   acrank_conv_kom=0.5  # for convection-diffusion
#  acrank_conv_keps=0.5
   acrank_conv_keps=1  # for convection-diffusion

   imon=0
   jmon=0
   kmon=0


   nsweep_vel=500
   nsweep_keps=5
   nsweep_kom=5
   convergence_limit_vel=1e-5

   convergence_limit_eps=1e-4
   convergence_limit_k=1e-4
   convergence_limit_om=1e-10
   convergence_limit_p=5e-4
   nsweep_p=1000

   datax= np.loadtxt("x2d.dat")
   x=datax[0:-1]
   ni=int(datax[-1])
   datay= np.loadtxt("y2d.dat")
   y=datay[0:-1]
   nj=int(datay[-1])

   x2d=np.reshape(x,(ni+1,nj+1))
   y2d=np.reshape(y,(ni+1,nj+1))

# compute cell centers
   xp2d=0.25*(x2d[0:-1,0:-1]+x2d[0:-1,1:]+x2d[1:,0:-1]+x2d[1:,1:])
   yp2d=0.25*(y2d[0:-1,0:-1]+y2d[0:-1,1:]+y2d[1:,0:-1]+y2d[1:,1:])

   zmax, nk=np.loadtxt('z.dat')
   # np.int was removed in NumPy 1.24; the builtin int is what it aliased
   nk=int(nk)
   dz=zmax/nk

   uin=20  # reference velocity for dt and the residual scaling

   dt=0.4*(x2d[1,0]-x2d[0,0])*np.ones(ntstep)/uin

# synthetic inlet fluct
   L_t_synt=0.2
   nmodes_synt=150
   jmirror_synt=int(nj/2) # mirror vsynt at node jmirror; jmirror=0 means no mirroring

# boundary conditions for u
   u_bc_west=np.zeros((nj,nk))
   u_bc_east=np.zeros((nj,nk))
   u_bc_south=np.zeros((ni,nk))
   u_bc_north=np.zeros((ni,nk))
   u_bc_z=0

   u_bc_west_type='d'
   u_bc_east_type='n'
   u_bc_south_type='d'
   u_bc_north_type='d'
   u_bc_z_type='n'


# boundary conditions for v
   v_bc_west=np.zeros((nj,nk))
   v_bc_east=np.zeros((nj,nk))
   v_bc_south=np.zeros((ni,nk))
   v_bc_north=np.zeros((ni,nk))
   v_bc_z=0

   v_bc_west_type='d'
   v_bc_east_type='n'
   v_bc_south_type='d'
   v_bc_north_type='d'
   v_bc_z_type='n'

# boundary conditions for w
   w_bc_west=np.zeros((nj,nk))
   w_bc_east=np.zeros((nj,nk))
   w_bc_south=np.zeros((ni,nk))
   w_bc_north=np.zeros((ni,nk))
   w_bc_z=0

   w_bc_west_type='d'
   w_bc_east_type='n'
   w_bc_south_type='d'
   w_bc_north_type='d'
   w_bc_z_type='d'

# boundary conditions for p
   p_bc_west=np.zeros((nj,nk))
   p_bc_east=np.zeros((nj,nk))
   p_bc_south=np.zeros((ni,nk))
   p_bc_north=np.zeros((ni,nk))
   p_bc_z=0

   p_bc_west_type='n'
   p_bc_east_type='n'
   p_bc_south_type='n'
   p_bc_north_type='n'
   p_bc_z_type='n'


# boundary conditions for k
   k_bc_west=np.zeros((nj,nk))
   k_bc_east=np.zeros((nj,nk))
   k_bc_south=np.zeros((ni,nk))
   k_bc_north=np.zeros((ni,nk))
   k_bc_z=0

   eps_bc_west=np.zeros((nj,nk))
   eps_bc_east=np.zeros((nj,nk))
   eps_bc_south=np.zeros((ni,nk))
   eps_bc_north=np.zeros((ni,nk))

   k_bc_west_type='d'
   k_bc_east_type='d'
   k_bc_south_type='d'
   k_bc_north_type='d'
   k_bc_z_type='n'

   eps_bc_z=0
   eps_bc_west_type='d'
   eps_bc_east_type='d'
   eps_bc_south_type='d'
   eps_bc_north_type='d'
   eps_bc_z_type='n'

# boundary conditions for om
   om_bc_west=np.zeros((nj,nk))
   om_bc_east=np.zeros((nj,nk))
   om_bc_south=np.zeros((ni,nk))
   om_bc_north=np.zeros((ni,nk))
   xwall_s=0.5*(x2d[0:-1,0]+x2d[1:,0])
   ywall_s=0.5*(y2d[0:-1,0]+y2d[1:,0])
   dist2_s=(yp2d[:,0]-ywall_s)**2+(xp2d[:,0]-xwall_s)**2
   om_bc_south=6*viscos/0.075/dist2_s

# make it 2D
   om_bc_south=np.repeat(om_bc_south[:,None], repeats=nk, axis=1)

   xwall_n=0.5*(x2d[0:-1,-1]+x2d[1:,-1])
   ywall_n=0.5*(y2d[0:-1,-1]+y2d[1:,-1])
   dist2_n=(yp2d[:,-1]-ywall_n)**2+(xp2d[:,-1]-xwall_n)**2
   om_bc_north=6*viscos/0.075/dist2_n

# make it 2D
   om_bc_north=np.repeat(om_bc_north[:,None], repeats=nk, axis=1)
   om_bc_z=0

   om_bc_west_type='d'
   om_bc_east_type='d'
   om_bc_south_type='d'
   om_bc_north_type='d'
   om_bc_z_type='n'


   if cyclic_x:
      u_bc_west_type='n'
      u_bc_east_type='n'
      v_bc_west_type='n'
      v_bc_east_type='n'
      w_bc_west_type='n'
      w_bc_east_type='n'
      p_bc_west_type='n'
      p_bc_east_type='n'
      k_bc_west_type='n'
      k_bc_east_type='n'
      eps_bc_west_type='n'
      eps_bc_east_type='n'
      om_bc_west_type='n'
      om_bc_east_type='n'

   if cyclic_z:
      u_bc_z_type='n'
      v_bc_z_type='n'
      w_bc_z_type='n'
      p_bc_z_type='n'
      k_bc_z_type='n'
      eps_bc_z_type='n'
      om_bc_z_type='n'


# scale residuals
   resnorm_p=uin*zmax*y2d[1,-1]
   resnorm_vel=uin**2*zmax*y2d[1,-1]

   return \
   acrank,acrank_conv, acrank_conv_keps, acrank_conv_kom, c_eps_1, c_eps_2, c_omega_1, c_omega_2, cmu, \
   convergence_limit_eps, convergence_limit_k, convergence_limit_om, convergence_limit_p, convergence_limit_vel, \
   cyclic_x, cyclic_z, dt, dz, eps_bc_east, eps_bc_east_type, eps_bc_north, eps_bc_north_type, eps_bc_south, \
   eps_bc_south_type, eps_bc_west, eps_bc_west_type, eps_bc_z, eps_bc_z_type, imon, itstep_average, \
   itstep_save, itstep_start,jmirror_synt,jmon,k_bc_east,k_bc_east_type,k_bc_north,k_bc_north_type,k_bc_south,k_bc_south_type, \
   k_bc_west, k_bc_west_type,k_bc_z,k_bc_z_type, keps,kmon,kom, kom_des,L_t_synt,maxit, ni,nj,nk,nmodes_synt,nsweep_keps, \
   nsweep_kom, nsweep_p, nsweep_vel, ntstep, om_bc_east, om_bc_east_type, om_bc_north, om_bc_north_type, \
   om_bc_south, om_bc_south_type, om_bc_west, om_bc_west_type, om_bc_z, om_bc_z_type, p_bc_east, p_bc_east_type, \
   p_bc_north, p_bc_north_type, p_bc_south, p_bc_south_type, p_bc_west, p_bc_west_type, p_bc_z, p_bc_z_type, pans, \
   prand_eps, prand_k, prand_omega, resnorm_p, resnorm_vel, restart, save, scheme, scheme_keps, smag, sormax, \
   u_bc_east, u_bc_east_type, u_bc_north, u_bc_north_type, u_bc_south, u_bc_south_type, u_bc_west, u_bc_west_type, \
   u_bc_z, u_bc_z_type, urfvis, v_bc_east, v_bc_east_type, v_bc_north, v_bc_north_type, v_bc_south, v_bc_south_type, \
   v_bc_west, v_bc_west_type, v_bc_z, v_bc_z_type, viscos, w_bc_east, w_bc_east_type, w_bc_north, w_bc_north_type, \
   w_bc_south, w_bc_south_type, w_bc_west, w_bc_west_type, w_bc_z, w_bc_z_type, wale, x2d, xp2d, y2d, yp2d, zmax
#/usr/bin/time -o out python -u py-3D-les-des-clean.py > out
# -u makes python's output unbuffered, so 'out' is written immediately
from scipy import sparse
import numpy as np
import sys 
import time
import pyamg
from scipy.sparse import spdiags,linalg,eye

def init():
   """Set up the case parameters and the grid metrics as module globals.

   Calls case_setup() and stores all returned values as module globals. It
   then derives the geometric quantities used by the solver:

   * dist3d  -- y-distance from each cell centre to the nearer of the south
     and north walls, shape (ni, nj, 1)
   * fx, fy  -- interpolation weights of the west/south faces;
     compute_face_phi uses face = f*phi_P + (1-f)*phi_(i-1 resp. j-1)
   * areawx, areawy, areaw / areasx, areasy, areas -- components and
     magnitude of the west / south face-area vectors, shapes (ni+1,nj,nk) /
     (ni,nj+1,nk)
   * areaz, shape (ni, nj, nk+1), and vol, shape (ni, nj, nk)
   * deltae, deltan -- centre-to-centre distances in i and j
   * as_bound, an_bound, aw_bound, ae_bound, az_bound -- wall coefficients
     (area**2 / half cell volume, resp. area / (dz/2)), without viscosity

   NOTE(review): several names in the global statement (e.g. delta_max,
   itstep_average_counter, nsweep_t, residual_p, vis3d, x, y) are not
   assigned here. Presumably they are set elsewhere -- confirm.

   Returns None; everything is communicated through module globals.
   """
   import numpy as np
   import sys

   global acrank, acrank_conv, acrank_conv_keps, acrank_conv_kom, ae_bound, an_bound, areas, areasx, areasy, areaw, areawx, \
     areawy, areaz, as_bound, aw_bound, az_bound, c_eps_1, c_eps_2, c_omega_1, c_omega_2, cmu, convergence_limit_eps, \
     convergence_limit_k, convergence_limit_om, convergence_limit_p, convergence_limit_vel, cyclic_x, cyclic_z, delta_max,\
     deltae, deltan, dist3d, dt, dz, eps_bc_east, eps_bc_east_type, eps_bc_north, eps_bc_north_type, eps_bc_south, \
     eps_bc_south_type, eps_bc_west, eps_bc_west_type, eps_bc_z, eps_bc_z_type, fx, fy, imon, itstep_average,itstep_average_counter,\
     itstep_save,itstep_start,jmirror_synt,jmon,k_bc_east,k_bc_east_type,k_bc_north,k_bc_north_type,k_bc_south,k_bc_south_type,\
     k_bc_west,k_bc_west_type,k_bc_z,k_bc_z_type,keps,kmon,kom,kom_des,L_t_synt,maxit,ni,nj,nk,nmodes_synt,nsweep_keps, \
     nsweep_kom, nsweep_p, nsweep_t, nsweep_vel, ntstep, om_bc_east, om_bc_east_type, om_bc_north, om_bc_north_type, \
     om_bc_south, om_bc_south_type, om_bc_west, om_bc_west_type, om_bc_z, om_bc_z_type, p_bc_east, p_bc_east_type, \
     p_bc_north, p_bc_north_type, p_bc_south, p_bc_south_type, p_bc_west, p_bc_west_type, p_bc_z, p_bc_z_type, pans, \
     prand_eps, prand_k, prand_omega, residual_p, residual_u, residual_v, resnorm_p, resnorm_vel, restart, save, scheme, \
     scheme_keps, smag, sormax, u_bc_east, u_bc_east_type, u_bc_north, u_bc_north_type, u_bc_south, u_bc_south_type, u_bc_west,\
     u_bc_west_type, u_bc_z, u_bc_z_type, urfvis, v_bc_east, v_bc_east_type, v_bc_north, v_bc_north_type, v_bc_south, \
     v_bc_south_type, v_bc_west, v_bc_west_type, v_bc_z, v_bc_z_type, vis3d, viscos, vol, w_bc_east, w_bc_east_type, \
     w_bc_north, w_bc_north_type, w_bc_south, w_bc_south_type, w_bc_west, w_bc_west_type, w_bc_z, w_bc_z_type, wale, x, \
     x2d, xp2d, y, y2d, yp2d, zmax

   # the unpacking order must match the return tuple of case_setup()
   acrank,acrank_conv, acrank_conv_keps, acrank_conv_kom, c_eps_1, c_eps_2, c_omega_1, c_omega_2, cmu, \
   convergence_limit_eps, convergence_limit_k, convergence_limit_om, convergence_limit_p, convergence_limit_vel, \
   cyclic_x, cyclic_z, dt, dz, eps_bc_east, eps_bc_east_type, eps_bc_north, eps_bc_north_type, eps_bc_south, \
   eps_bc_south_type, eps_bc_west, eps_bc_west_type, eps_bc_z, eps_bc_z_type, imon, itstep_average, \
   itstep_save,itstep_start,jmirror_synt,jmon,k_bc_east,k_bc_east_type,k_bc_north,k_bc_north_type,k_bc_south,k_bc_south_type, \
   k_bc_west, k_bc_west_type, k_bc_z, k_bc_z_type,keps,kmon, kom,kom_des, L_t_synt,maxit, ni,nj,nk,nmodes_synt,nsweep_keps, \
   nsweep_kom, nsweep_p, nsweep_vel, ntstep, om_bc_east, om_bc_east_type, om_bc_north, om_bc_north_type, \
   om_bc_south, om_bc_south_type, om_bc_west, om_bc_west_type, om_bc_z, om_bc_z_type, p_bc_east, p_bc_east_type, \
   p_bc_north, p_bc_north_type, p_bc_south, p_bc_south_type, p_bc_west, p_bc_west_type, p_bc_z, p_bc_z_type, pans, \
   prand_eps, prand_k, prand_omega, resnorm_p, resnorm_vel, restart, save, scheme, scheme_keps, smag, sormax, \
   u_bc_east, u_bc_east_type, u_bc_north, u_bc_north_type, u_bc_south, u_bc_south_type, u_bc_west, u_bc_west_type, \
   u_bc_z, u_bc_z_type, urfvis, v_bc_east, v_bc_east_type, v_bc_north, v_bc_north_type, v_bc_south, v_bc_south_type, \
   v_bc_west, v_bc_west_type, v_bc_z, v_bc_z_type, viscos, w_bc_east, w_bc_east_type, w_bc_north, w_bc_north_type, \
   w_bc_south, w_bc_south_type, w_bc_west, w_bc_west_type, w_bc_z, w_bc_z_type, wale, x2d, xp2d, y2d, yp2d, zmax = \
    case_setup()


   # wall distance: y-distance (only) to the south and north walls, the smaller one is kept
   ywall_s=0.5*(y2d[0:-1,0]+y2d[1:,0])
   dist_s=yp2d-ywall_s[:,None]
   ywall_n=0.5*(y2d[0:-1,-1]+y2d[1:,-1])
   dist_n=ywall_n[:,None] -yp2d
   dist=np.minimum(dist_s,dist_n)
   dist3d=dist[:,:,None]

# west face coordinate
   xw=0.5*(x2d[0:-1,0:-1]+x2d[0:-1,1:])
   yw=0.5*(y2d[0:-1,0:-1]+y2d[0:-1,1:])


   # del1x: west face -> own centre; del2x: west face -> centre of cell i-1
   # (np.roll wraps at i=0, so fx[0] is only meaningful in the cyclic case)
   del1x=((xw-xp2d)**2+(yw-yp2d)**2)**0.5
   del2x=((xw-np.roll(xp2d,1,axis=0))**2+(yw-np.roll(yp2d,1,axis=0))**2)**0.5
   fx=del2x/(del1x+del2x)
   fx = np.dstack([fx]*nk)

   if cyclic_x:
     fx[0,:,:]=0.5

   # south face centre and the corresponding weights fy (same construction as fx, in j)
   xs=0.5*(x2d[0:-1,0:-1]+x2d[1:,0:-1])
   ys=0.5*(y2d[0:-1,0:-1]+y2d[1:,0:-1])

   del1y=((xs-xp2d)**2+(ys-yp2d)**2)**0.5
   del2y=((xs-np.roll(xp2d,1,axis=1))**2+(ys-np.roll(yp2d,1,axis=1))**2)**0.5
   fy=del2y/(del1y+del2y)
   fy = np.dstack([fy]*nk)

   # west face-area vector: (-dy*dz, dx*dz) along each j-edge
   areawy=np.diff(x2d,axis=1)*dz
   areawx=-np.diff(y2d,axis=1)*dz

# make them 3d
   areawx= np.dstack([areawx]*nk)
   areawy= np.dstack([areawy]*nk)

   # south face-area vector: (dy*dz, -dx*dz) along each i-edge
   areasy=-np.diff(x2d,axis=0)*dz
   areasx=np.diff(y2d,axis=0)*dz
# make them 3d
   areasx= np.dstack([areasx]*nk)
   areasy= np.dstack([areasy]*nk)

#  areaz=np.zeros((ni,nj,nk+1))

   # face-area magnitudes
   areaw=(areawx**2+areawy**2)**0.5
   areas=(areasx**2+areasy**2)**0.5

   deltax=np.diff(xp2d,axis=0)
   deltay=np.diff(yp2d,axis=0)
   deltae=(deltax**2+deltay**2)**0.5
# duplicate last row and put it at the end
# (np.insert at -1 places the copy before the last row; since it equals the
#  last row, the result is the same as appending it)
   deltae=np.insert(deltae,-1,deltae[-1],axis=0)
# make it 3d
   deltae= np.dstack([deltae]*nk)

   deltax=np.diff(xp2d,axis=1)
   deltay=np.diff(yp2d,axis=1)
   deltan=(deltax**2+deltay**2)**0.5
# duplicate last column and put it at the end (same np.insert remark as above)
   deltan=np.insert(deltan,-1,deltan[:,-1],axis=1)
# make it 3d
   deltan= np.dstack([deltan]*nk)


# cell area approximated as the sum of two triangles (cross products):
# corners (i,j),(i,j+1),(i+1,j) and (i,j),(i+1,j),(i+1,j+1);
# exact for parallelogram cells, an approximation otherwise
   ax=np.diff(x2d,axis=1)
   ay=np.diff(y2d,axis=1)
   bx=np.diff(x2d,axis=0)
   by=np.diff(y2d,axis=0)

   areaz_1=0.5*np.absolute(ax[0:-1,:]*by[:,0:-1]-ay[0:-1,:]*bx[:,0:-1])

   ax=np.diff(x2d,axis=1)
   ay=np.diff(y2d,axis=1)
   areaz_2=0.5*np.absolute(ax[1:,:]*by[:,0:-1]-ay[1:,:]*bx[:,0:-1])

   areaz=areaz_1+areaz_2
# make it 3d
   vol=areaz*dz
   vol= np.dstack([vol]*nk)
#  vol= np.loadtxt("vol.dat")
#  vol=np.reshape(vol,(ni,nj))
#  vol= np.dstack([vol]*nk)

# make it 3d
   areaz= np.dstack([areaz]*(nk+1))

# coeff at south wall (without viscosity)
   as_bound=areas[:,0,:]**2/(0.5*vol[:,0,:])

# coeff at north wall (without viscosity)
   an_bound=areas[:,-1,:]**2/(0.5*vol[:,-1,:])

# coeff at west wall (without viscosity)
   aw_bound=areaw[0,:,:]**2/(0.5*vol[0,:,:])
   if cyclic_x:
      aw_bound=areaw[0,:,:]**2/(0.5*(vol[0,:,:]+vol[-1,:,:]))

# coeff at east wall (without viscosity) N.B: which cyclic_x
# this is never used 
   ae_bound=areaw[-1,:,:]**2/(0.5*vol[-1,:,:])

# make it 2d
   az_bound=areaz[:,:,0]/(0.5*dz)  # wall node located AT the boundary

   return


def compute_face_phi(phi3d,phi_bc_west,phi_bc_east,phi_bc_south,phi_bc_north,phi_bc_z,\
    phi_bc_west_type,phi_bc_east_type,phi_bc_south_type,phi_bc_north_type,phi_bc_z_type):
   """Interpolate a cell-centred field to the west, south and low faces.

   Interior faces use the weights fx / fy (i / j direction) and the
   arithmetic mean in k. Each boundary face first gets the prescribed value.
   It is then overridden by the adjacent cell value when its type is 'n', and
   by the mean of the first and last cells when the direction is cyclic.

   Returns
   -------
   face_w : ndarray (ni+1, nj, nk)
   face_s : ndarray (ni, nj+1, nk)
   face_l : ndarray (ni, nj, nk+1)
   """
   import numpy as np

   west_faces=np.empty((ni+1,nj,nk))
   south_faces=np.empty((ni,nj+1,nk))
   low_faces=np.empty((ni,nj,nk+1))

   # interior faces, weighted with the neighbour at i-1 / j-1 / k-1
   # (np.roll wraps around; the first face in each direction is replaced below)
   west_faces[:-1]=fx*phi3d+(1-fx)*np.roll(phi3d,1,axis=0)
   south_faces[:,:-1]=fy*phi3d+(1-fy)*np.roll(phi3d,1,axis=1)
   low_faces[:,:,:-1]=0.5*np.roll(phi3d,1,axis=2)+0.5*phi3d

   # west / east: prescribed value, then Neumann, then periodic override
   west_faces[0]=phi_bc_west
   west_faces[-1]=phi_bc_east
   if phi_bc_west_type == 'n':
      west_faces[0]=phi3d[0]
   if phi_bc_east_type == 'n':
      west_faces[-1]=phi3d[-1]
   if cyclic_x:
      west_faces[0]=west_faces[-1]=0.5*(phi3d[0]+phi3d[-1])

   # south / north: prescribed value, then Neumann override
   south_faces[:,0]=phi_bc_south
   south_faces[:,-1]=phi_bc_north
   if phi_bc_south_type == 'n':
      south_faces[:,0]=phi3d[:,0]
   if phi_bc_north_type == 'n':
      south_faces[:,-1]=phi3d[:,-1]

   # low / high: prescribed value, then Neumann, then periodic override
   low_faces[:,:,0]=phi_bc_z
   low_faces[:,:,-1]=phi_bc_z
   if phi_bc_z_type == 'n':
      low_faces[:,:,0]=phi3d[:,:,0]
      low_faces[:,:,-1]=phi3d[:,:,-1]
   if cyclic_z:
      low_faces[:,:,0]=low_faces[:,:,-1]=0.5*(phi3d[:,:,-1]+phi3d[:,:,0])

   return west_faces,south_faces,low_faces

def dphidx(phi_face_w,phi_face_s):
   """Cell-centred d(phi)/dx by the Gauss (divergence) theorem.

   Face values times the x-components of the face-area vectors are summed
   over the four lateral faces: west/south faces added, east/north faces
   (the next cell's west/south faces) subtracted. The sum is divided by the
   cell volume. Reads the module-level areawx, areasx and vol.
   """
   west=phi_face_w[0:-1,:,:]*areawx[0:-1,:,:]
   east=phi_face_w[1:,:,:]*areawx[1:,:,:]
   south=phi_face_s[:,0:-1,:]*areasx[:,0:-1,:]
   north=phi_face_s[:,1:,:]*areasx[:,1:,:]
   return (((west-east)+south)-north)/vol

def dphidy(phi_face_w,phi_face_s):
   """Cell-centred d(phi)/dy by the Gauss (divergence) theorem.

   Face values times the y-components of the face-area vectors are summed
   over the four lateral faces: west/south faces added, east/north faces
   (the next cell's west/south faces) subtracted. The sum is divided by the
   cell volume. Reads the module-level areawy, areasy and vol.
   """
   west=phi_face_w[0:-1,:,:]*areawy[0:-1,:,:]
   east=phi_face_w[1:,:,:]*areawy[1:,:,:]
   south=phi_face_s[:,0:-1,:]*areasy[:,0:-1,:]
   north=phi_face_s[:,1:,:]*areasy[:,1:,:]
   return (((west-east)+south)-north)/vol

def dphidz(phi_face_l):
   """Cell-centred d(phi)/dz: (high-face value - low-face value) / dz.

   phi_face_l holds the nk+1 low faces along axis 2; reads the module-level dz.
   """
   return np.diff(phi_face_l,axis=2)/dz

def coeff(convw,convs,convl,vis3d,prand,scheme_local):
   """Compute the neighbour coefficients of a convection-diffusion equation.

   The diffusive viscosity viscos + (vis3d - viscos)/prand is interpolated
   to the faces (fx / fy weights in i / j, arithmetic mean in k). It is then
   multiplied by |area|**2/volume (area/dz in k). Convection is discretised
   according to scheme_local:

   * 'h' -- hybrid upwind/central
   * 'u' -- first-order upwind
   * 'c' -- central differencing

   Coefficients pointing out of the domain are zeroed: south/north always,
   west/east and low/high unless that direction is cyclic.

   Parameters
   ----------
   convw, convs, convl : ndarray
       Convective fluxes through the west, south and low faces
       ((ni+1,nj,nk), (ni,nj+1,nk), (ni,nj,nk+1), as returned by conv()).
   vis3d : ndarray
       Total (molecular + turbulent) viscosity at the cell centres.
   prand : float
       Turbulent Prandtl/Schmidt number. A non-positive value is only
       allowed when pans is True; |prand|*fk3d**2 is then used instead.
   scheme_local : str
       'h', 'u' or 'c' (see above).

   Returns
   -------
   aw3d, ae3d, as3d, an3d, al3d, ah3d, apo3d, su3d, sp3d
       apo3d = vol/dt[itstep]. su3d and sp3d are the module-level arrays,
       returned unchanged.

   Raises
   ------
   ValueError
       If scheme_local is not 'h', 'u' or 'c', or if prand <= 0 while pans
       is False (these previously failed with NameError/UnboundLocalError).
   """
   if scheme_local not in ('h','u','c'):
      raise ValueError("coeff: unknown scheme_local %r (expected 'h', 'u' or 'c')" % (scheme_local,))

   visw=np.zeros((ni+1,nj,nk))
   viss=np.zeros((ni,nj+1,nk))
   visl=np.zeros((ni,nj,nk+1))
   if prand > 0:
      vis_turb=(vis3d-viscos)/prand
   elif pans: # k and eps in PANS
      vis_turb=(vis3d-viscos)/np.abs(prand)/fk3d**2
   else:
      raise ValueError('coeff: prand must be > 0 unless pans is True, got %r' % (prand,))

   # face viscosities (index 0 in each direction wraps via np.roll; fixed below when cyclic)
   visw[0:-1,:,:]=fx*vis_turb+(1-fx)*np.roll(vis_turb,1,axis=0)+viscos
   viss[:,0:-1,:]=fy*vis_turb+(1-fy)*np.roll(vis_turb,1,axis=1)+viscos
   visl[:,:,0:-1]=0.5*vis_turb+0.5*np.roll(vis_turb,1,axis=2)+viscos


   if cyclic_z:
      visl[:,:,0]=0.5*(vis_turb[:,:,0]+vis_turb[:,:,-1])+viscos

   # volumes between neighbouring centres; the boundary entry stays at 1e-10
   volw=np.ones((ni+1,nj,nk))*1e-10
   vols=np.ones((ni,nj+1,nk))*1e-10
   volw[1:,:,:]=0.5*np.roll(vol,-1,axis=0)+0.5*vol
   diffw=visw[0:-1,:,:]*areaw[0:-1,:,:]**2/volw[0:-1,:,:]
   vols[:,1:,:]=0.5*np.roll(vol,-1,axis=1)+0.5*vol
   diffs=viss[:,0:-1,:]*areas[:,0:-1,:]**2/vols[:,0:-1,:]
   diffl=visl[:,:,0:-1]*areaz[:,:,0:-1]/dz

   if cyclic_x:
      visw[0,:,:]=0.5*(vis_turb[0,:,:]+vis_turb[-1,:,:])+viscos
      diffw[0,:,:]=visw[0,:,:]*areaw[0,:,:]**2/(0.5*(vol[0,:,:]+vol[-1,:,:]))


   if scheme_local == 'h':
      if itstep == 0 and iter == 0:
         print('hybrid scheme, prand=',prand)

      aw3d=np.maximum(convw[0:-1,:,:],diffw+(1-fx)*convw[0:-1,:,:])
      aw3d=np.maximum(aw3d,0.)

      ae3d=np.maximum(-convw[1:,:,:],np.roll(diffw,-1,axis=0)-np.roll(fx,-1,axis=0)*convw[1:,:,:])
      ae3d=np.maximum(ae3d,0.)

      as3d=np.maximum(convs[:,0:-1,:],diffs+(1-fy)*convs[:,0:-1,:])
      as3d=np.maximum(as3d,0.)

      an3d=np.maximum(-convs[:,1:,:],np.roll(diffs,-1,axis=1)-np.roll(fy,-1,axis=1)*convs[:,1:,:])
      an3d=np.maximum(an3d,0.)

      al3d=np.maximum(convl[:,:,0:-1],diffl+0.5*convl[:,:,0:-1])
      al3d=np.maximum(al3d,0.)

      ah3d=np.maximum(-convl[:,:,1:],np.roll(diffl,-1,axis=2)-0.5*convl[:,:,1:])
      ah3d=np.maximum(ah3d,0.)

   if scheme_local == 'u':
      if itstep == 0 and iter == 0:
         print('upwind scheme, prand=',prand)

      aw3d=np.maximum(convw[0:-1,:,:],0)+diffw
      ae3d=np.maximum(-convw[1:,:,:],0)+np.roll(diffw,-1,axis=0)
      as3d=np.maximum(convs[:,0:-1,:],0)+diffs
      an3d=np.maximum(-convs[:,1:,:],0)+np.roll(diffs,-1,axis=1)
      al3d=np.maximum(convl[:,:,0:-1],0)+diffl
      ah3d=np.maximum(-convl[:,:,1:],0)+np.roll(diffl,-1,axis=2)

   if scheme_local == 'c':
      if itstep == 0 and iter == 0:
         print('CDS scheme, prand=',prand)
      aw3d=diffw+(1-fx)*convw[0:-1,:,:]
      ae3d=np.roll(diffw,-1,axis=0)-np.roll(fx,-1,axis=0)*convw[1:,:,:]

      as3d=diffs+(1-fy)*convs[:,0:-1,:]
      an3d=np.roll(diffs,-1,axis=1)-np.roll(fy,-1,axis=1)*convs[:,1:,:]

      al3d=diffl+0.5*convl[:,:,0:-1]
      ah3d=np.roll(diffl,-1,axis=2)-0.5*convl[:,:,1:]

   apo3d=vol/dt[itstep]

   # no coupling across non-periodic domain boundaries
   if not cyclic_x:
      aw3d[0,:,:]=0
      ae3d[-1,:,:]=0
   as3d[:,0,:]=0
   an3d[:,-1,:]=0
   if not cyclic_z:
      al3d[:,:,0]=0
      ah3d[:,:,-1]=0

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d

def bc(su3d,sp3d,phi_bc_west,phi_bc_east,phi_bc_south,phi_bc_north,phi_bc_z\
     ,phi_bc_west_type,phi_bc_east_type,phi_bc_south_type,phi_bc_north_type,phi_bc_z_type):
   """Return new (su, sp) source arrays containing the Dirichlet wall terms.

   For every boundary of type 'd' the wall coefficient viscos*a_bound is
   added implicitly (sp -= viscos*a_bound), and the wall value explicitly
   (su += viscos*a_bound*phi_bc). Other types contribute nothing; cyclic x
   is not treated here. The incoming su3d / sp3d are ignored: fresh zero
   arrays of shape (ni, nj, nk) are built. Reads the module-level ni, nj,
   nk, viscos and the *_bound coefficients set in init().
   """
   src=np.zeros((ni,nj,nk))
   lin=np.zeros((ni,nj,nk))

   def _wall(where,coef,value):
      # one Dirichlet face: implicit part, then explicit part
      lin[where]=lin[where]-viscos*coef
      src[where]=src[where]+viscos*coef*value

   if phi_bc_south_type == 'd':
      _wall(np.s_[:,0,:],as_bound,phi_bc_south)
   if phi_bc_north_type == 'd':
      _wall(np.s_[:,-1,:],an_bound,phi_bc_north)
   if phi_bc_west_type == 'd':
      _wall(np.s_[0,:,:],aw_bound,phi_bc_west)
   if phi_bc_east_type == 'd':
      _wall(np.s_[-1,:,:],ae_bound,phi_bc_east)
   if phi_bc_z_type == 'd':
      _wall(np.s_[:,:,0],az_bound,phi_bc_z)
      _wall(np.s_[:,:,-1],az_bound,phi_bc_z)

   return src,lin

def conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l):
   """Compute the convective fluxes through the west, south and low faces.

   The cell velocities are first shifted by dt[itstep]*acrank times the
   pressure gradient. Presumably this removes the implicit pressure-gradient
   contribution before face interpolation -- confirm. They are then
   interpolated to the faces with their boundary conditions and multiplied
   by the face-area vectors.

   Returns
   -------
   convw, convs, convl : ndarray
       Shapes (ni+1,nj,nk), (ni,nj+1,nk), (ni,nj,nk+1).
   """
   step=dt[itstep]*acrank
   u_star=u3d+dphidx(p3d_face_w,p3d_face_s)*step
   v_star=v3d+dphidy(p3d_face_w,p3d_face_s)*step
   w_star=w3d+dphidz(p3d_face_l)*step

   uw,us,_=compute_face_phi(u_star,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,\
    u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_z_type)
   vw,vs,_=compute_face_phi(v_star,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_z,\
    v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_z_type)
   _,_,wl=compute_face_phi(w_star,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_z,\
    w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_z_type)

   flux_w=-uw*areawx-vw*areawy
   flux_s=-us*areasx-vs*areasy
   flux_l=wl*areaz
   return flux_w,flux_s,flux_l
   
def solve_3d(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv,nmax,acrank_conv_local):
   """Solve one implicit transport equation with LGMRES.

   For every cell P the assembled system is
      ap*phi_P - (aw*phi_W + ae*phi_E + as*phi_S + an*phi_N + al*phi_L + ah*phi_H) = su
   with the neighbour coefficients scaled by acrank_conv_local. Cells are
   numbered in C order of (ni, nj, nk), so the k-, j- and i-neighbours sit
   at offsets 1, nk and nj*nk. Periodic couplings (cyclic_z: k=0 <-> k=nk-1
   within each (i,j) line; cyclic_x: i=0 <-> i=ni-1) go on extra diagonals.

   Parameters
   ----------
   phi3d : ndarray (ni, nj, nk)
       Initial guess.
   aw3d .. ah3d, su3d, ap3d : ndarray (ni, nj, nk)
       Coefficients and source; they are not modified.
   tol_conv : float
       Used as both the absolute and the relative LGMRES tolerance.
   nmax : int
       Maximum number of LGMRES iterations.
   acrank_conv_local : float
       Implicitness factor applied to the neighbour coefficients.

   Returns
   -------
   phi3d : ndarray (ni, nj, nk)
   resid : float
       2-norm of A @ phi - su.
   """
   if itstep == 0 and iter == 0:
      print('solve_3d called')
      print('nmax,acrank_conv_local',nmax,acrank_conv_local)

   # flatten() copies, so the in-place zeroing below never alters the caller's arrays
   aw=aw3d.flatten()*acrank_conv_local
   ae=ae3d.flatten()*acrank_conv_local
   as1=as3d.flatten()*acrank_conv_local
   an=an3d.flatten()*acrank_conv_local
   al=al3d.flatten()*acrank_conv_local
   ah=ah3d.flatten()*acrank_conv_local
   ap=ap3d.flatten()

   m=ni*nj*nk
   nij=nj*nk   # distance between i-neighbours in the flat numbering

   if cyclic_z:
      # The al entry of k=0 and the ah entry of k=nk-1 in each (i,j) line
      # point to the periodic partner; move them off the +-1 diagonals.
      al_cyc=np.zeros(m)
      al_cyc[0:-1:nk]=al[0:-1:nk]
      al[0:-1:nk]=0
      ah_cyc=np.zeros(m)
      # nk-1::nk includes the very last cell (m-1); the former nk-1:-1:nk
      # slice silently dropped its periodic coupling when cyclic_x was off
      ah_cyc[nk-1::nk]=ah[nk-1::nk]
      ah[nk-1::nk]=0

   diagonals=[ap,-ah[:-1],-al[1:]]
   offsets=[0,1,-1]
   if cyclic_z:
      diagonals+=[-al_cyc,-ah_cyc[nk-1:]]
      offsets+=[nk-1,-(nk-1)]
   diagonals+=[-an[0:-nk],-as1[nk:],-ae,-aw[nij:]]
   offsets+=[nk,-nk,nij,-nij]
   if cyclic_x:
      # periodic coupling between the planes i=0 and i=ni-1
      diagonals+=[-aw,-ae[nij*(ni-1):]]
      offsets+=[nij*(ni-1),-nij*(ni-1)]
   # over-long diagonals (e.g. -ae, -al_cyc) are truncated by sparse.diags
   A=sparse.diags(diagonals,offsets,shape=(m,m),format='csc')

   su=su3d.flatten()
   phi=phi3d.flatten()
   try:
      phi,info=linalg.lgmres(A,su,x0=phi,atol=tol_conv,rtol=tol_conv,maxiter=nmax)
   except TypeError as err:
      # SciPy < 1.12 has no ``rtol``; its relative tolerance is ``tol``
      # (which in turn was removed in SciPy 1.14)
      if 'rtol' not in str(err):
         raise
      phi,info=linalg.lgmres(A,su,x0=phi,atol=tol_conv,tol=tol_conv,maxiter=nmax)
# check residual
   resid=np.linalg.norm(A@phi-su)

   phi3d=np.reshape(phi,(ni,nj,nk))

   return phi3d,resid


def solve_p(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv):
   """Solve the pressure equation with a Ruge-Stuben algebraic multigrid solver.

   The sparse matrix (same layout as in solve_3d: C-order numbering with
   k-, j-, i-neighbours at offsets 1, nk, nj*nk, plus periodic diagonals) is
   assembled on the first call only (itstep == 0 and iter == 0). The AMG
   hierarchy built from it is kept in the module global Ap and reused
   afterwards.
   NOTE(review): this assumes the pressure coefficients never change and
   that the first call happens at itstep == 0, iter == 0 -- confirm.

   Returns
   -------
   phi3d : ndarray (ni, nj, nk)
   """
   global Ap
   if itstep == 0 and iter == 0:
      print('solve_p called')
      print('A and M computed,tol_conv=',tol_conv)
      # flatten() copies, so the in-place zeroing below never alters the caller's arrays
      aw=aw3d.flatten()
      ae=ae3d.flatten()
      as1=as3d.flatten()
      an=an3d.flatten()
      al=al3d.flatten()
      ah=ah3d.flatten()
      ap=ap3d.flatten()

      m=ni*nj*nk
      nij=nj*nk   # distance between i-neighbours in the flat numbering

      if cyclic_z:
         # move the periodic k-couplings (k=0 <-> k=nk-1) off the +-1 diagonals;
         # nk-1::nk includes the last cell (m-1), which the former
         # nk-1:-1:nk slice dropped when cyclic_x was off
         al_cyc=np.zeros(m)
         al_cyc[0:-1:nk]=al[0:-1:nk]
         al[0:-1:nk]=0
         ah_cyc=np.zeros(m)
         ah_cyc[nk-1::nk]=ah[nk-1::nk]
         ah[nk-1::nk]=0

      diagonals=[ap,-ah[:-1],-al[1:]]
      offsets=[0,1,-1]
      if cyclic_z:
         diagonals+=[-al_cyc,-ah_cyc[nk-1:]]
         offsets+=[nk-1,-(nk-1)]
      diagonals+=[-an[0:-nk],-as1[nk:],-ae,-aw[nij:]]
      offsets+=[nk,-nk,nij,-nij]
      if cyclic_x:
         # periodic coupling between the planes i=0 and i=ni-1
         diagonals+=[-aw,-ae[nij*(ni-1):]]
         offsets+=[nij*(ni-1),-nij*(ni-1)]
      # over-long diagonals (e.g. -ae, -al_cyc) are truncated by sparse.diags
      matrix=sparse.diags(diagonals,offsets,shape=(m,m),format='csc')

      Ap = pyamg.ruge_stuben_solver(matrix)   # construct the multigrid hierarchy

   print('in solve_p')

   phi=phi3d.flatten()
   su=su3d.flatten()
   phi = Ap.solve(su, tol=tol_conv, x0=phi)

   phi3d=np.reshape(phi,(ni,nj,nk))

   return phi3d


def calcu(su3d,sp3d,dpdx_old,p3d_face_w,p3d_face_s):
   """Add the pressure-gradient source to the u equation, then apply modify_u.

   The gradient is blended as acrank*dp/dx(current) + (1-acrank)*dpdx_old.
   The unsteady term is not added here.
   """
   if itstep == 0 and iter == 0:
      print('calcu called')

   blended_dpdx=acrank*dphidx(p3d_face_w,p3d_face_s)+(1-acrank)*dpdx_old
   su_new,sp_new=modify_u(su3d-blended_dpdx*vol,sp3d)
   return su_new,sp_new

def calcv(su3d,sp3d,dpdy_old,p3d_face_w,p3d_face_s):
   """Add the pressure-gradient source to the v equation, then apply modify_v.

   The gradient is blended as acrank*dp/dy(current) + (1-acrank)*dpdy_old.
   The unsteady term is not added here.
   """
   if itstep == 0 and iter == 0:
      print('calcv called')

   blended_dpdy=acrank*dphidy(p3d_face_w,p3d_face_s)+(1-acrank)*dpdy_old
   su_new,sp_new=modify_v(su3d-blended_dpdy*vol,sp3d)
   return su_new,sp_new

def compute_fk(k3d,eps3d):
   """Return the PANS parameter fk computed from a length-scale ratio.

   psi = max(1, (k**1.5/eps) / (0.67*delta_max)) and
   fk  = max(1 - (psi - 1)/(c_eps_2 - c_eps_1), 0.2).
   Reads the module-level delta_max, c_eps_1 and c_eps_2.
   """
   if itstep == 0 and iter == 0:
      print('compute_fk called')

   cdes,fk_floor=0.67,0.2
   psi=np.maximum(1,(k3d**1.5/eps3d)/(cdes*delta_max))
   return np.maximum(1.-(psi-1.)/(c_eps_2-c_eps_1),fk_floor)


def calck_kom(su3d,sp3d,k3d,om3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l):
   """Add the k-equation sources of the k-omega (and k-omega DES) model.

   Production : su += vist*gen*vol, with vist = max(vis3d - viscos, 1e-10)
                and gen = 2*S_ij*S_ij built from all nine velocity gradients.
   Dissipation: sp -= fk3d*cmu*om3d*vol (implicit). fk3d = 1 for plain
                k-omega, or max(1, rl/(0.67*delta_max)) with
                rl = sqrt(k)/(cmu*om) when kom_des is set.
   Case hooks : modify_k() (previously modify_w() was called by mistake,
                adding the w-inlet source terms to the k equation).

   NOTE(review): u3d_face_l, v3d_face_l, w3d_face_w and w3d_face_s are not
   parameters; they are read from module scope.

   Returns
   -------
   su3d, sp3d, gen, fk3d
   """
   if itstep == 0 and iter == 0:
      print('calck_kom called')

   # production term
   dudx=dphidx(u3d_face_w,u3d_face_s)
   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dwdx=dphidx(w3d_face_w,w3d_face_s)

   dudy=dphidy(u3d_face_w,u3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)

   dudz=dphidz(u3d_face_l)
   dvdz=dphidz(v3d_face_l)
   dwdz=dphidz(w3d_face_l)

   gen= (2.*(dudx**2+dvdy**2+dwdz**2)+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2)
   vist=np.maximum(vis3d-viscos,1e-10)
   su3d=su3d+vist*gen*vol

   # turbulent length scale of the k-omega model
   rl=k3d**0.5/(cmu*om3d)

   if kom_des:
      fk3d=np.maximum(1.,rl/(0.67*delta_max))
   else:
      fk3d=1

   # dissipation term (implicit)
   sp3d=sp3d-fk3d*cmu*om3d*vol

   # case-specific sources of the k equation
   su3d,sp3d=modify_k(su3d,sp3d)

   # unsteady term added in crank_nicol

   return su3d,sp3d,gen,fk3d

def calcom(su3d,sp3d,om3d,gen):
   """Assemble the source terms of the omega equation.

   - Production ``c_omega_1*gen*vol`` is added explicitly to su3d.
   - Destruction ``c_omega_2*om*vol`` is linearized into sp3d.
   - The case hook ``modify_om`` is then applied.

   Returns the updated ``(su3d, sp3d)``.
   """
   if itstep == 0 and iter == 0:
      print('calcom called')

   production=c_omega_1*gen*vol
   destruction=c_omega_2*om3d*vol
   su3d=su3d+production
   sp3d=sp3d-destruction

   su3d,sp3d=modify_om(su3d,sp3d)
   return su3d,sp3d

def calck_ls(su3d,sp3d,k3d,eps3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l):
   """Assemble the source terms of the k equation for the low-Re k-eps/PANS model.

   Sources:
   - production ``vist*gen*vol`` with ``gen = 2 S_ij S_ij`` and
     ``vist = max(vis3d-viscos, 1e-10)`` (explicit, su3d);
   - dissipation ``eps/k*vol`` (linearized, sp3d);
   - D term ``2*viscos*(d sqrt(k)/dy)**2``, linearized into sp3d as
     ``D/k*vol`` (only the y-derivative is used).

   The case hook ``modify_k`` is applied last.

   NOTE: w3d_face_w, w3d_face_s, u3d_face_l and v3d_face_l are read from
   module-level globals, not from the argument list.

   Returns (su3d, sp3d, gen, dudx, dudy). calceps_ls uses dudx and dudy for
   its E term.
   """
   if itstep == 0 and iter == 0:
      print('calck_ls called')

   # velocity-gradient tensor
   dudx=dphidx(u3d_face_w,u3d_face_s)
   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dwdx=dphidx(w3d_face_w,w3d_face_s)
   dudy=dphidy(u3d_face_w,u3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)
   dudz=dphidz(u3d_face_l)
   dvdz=dphidz(v3d_face_l)
   dwdz=dphidz(w3d_face_l)

   # production
   gen= (2.*(dudx**2+dvdy**2+dwdz**2)+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2)
   nu_t=np.maximum(vis3d-viscos,1e-10)
   su3d=su3d+nu_t*gen*vol

   # dissipation
   sp3d=sp3d-eps3d/k3d*vol

   # D term from the wall-normal gradient of sqrt(k)
   sqrt_k=k3d**0.5
   sqrt_k_w,sqrt_k_s,sqrt_k_l=compute_face_phi(sqrt_k,k_bc_west,k_bc_east,k_bc_south,k_bc_north,k_bc_z,\
     k_bc_west_type,k_bc_east_type,k_bc_south_type,k_bc_north_type,k_bc_z_type)
   dterm=2.*viscos*dphidy(sqrt_k_w,sqrt_k_s)**2
   sp3d=sp3d-dterm/k3d*vol

   su3d,sp3d=modify_k(su3d,sp3d)

   # the unsteady term is added in crank_nicol
   return su3d,sp3d,gen,dudx,dudy

def calceps_ls(su3d,sp3d,k3d,eps3d,vis3d,gen,dudx,dudy):
   """Assemble the source terms of the eps equation for the low-Re k-eps/PANS model.

   Damping functions, with ``Rt = k**2/(eps*viscos)``:
   - ``f2  = 1 - 0.3*exp(-Rt**2)``
   - ``fmu = min(exp(-3.4/(1+Rt/50)**2), 1)``

   Sources:
   - production ``c_eps_1*cmu*fmu*gen*k*vol`` (explicit, su3d);
   - destruction ``c2u*eps/k*vol`` (linearized, sp3d), with
     ``c2u = c_eps_1 + fk*(f2*c_eps_2 - c_eps_1)``. ``fk3d`` is read from
     the module-level global;
   - E term ``2*viscos*vist*((d2u/dx2)**2 + (d2u/dy2)**2)*vol`` (su3d).
     The second derivatives are taken from dudx/dudy with Neumann face
     values, so the u_bc_* values passed to compute_face_phi are not used.

   The case hook ``modify_eps`` is applied last.

   Returns (su3d, sp3d, fmu3d).
   """
   if itstep == 0 and iter == 0:
      print('calceps_ls called')

   # damping functions
   rt=k3d**2/eps3d/viscos
   f2=1.-0.3*np.exp(-rt**2)
   fmu3d=np.minimum(np.exp(-3.4/(1.+rt/50.)**2),1.)

   # production and linearized destruction
   su3d=su3d+c_eps_1*cmu*fmu3d*gen*k3d*vol
   c2u=c_eps_1+fk3d*(f2*c_eps_2-c_eps_1)
   sp3d=sp3d-c2u*eps3d*vol/k3d

   # E term
   neumann=('n','n','n','n','n')
   dudy_w,dudy_s,dudy_l=compute_face_phi(dudy,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,*neumann)
   dudx_w,dudx_s,dudx_l=compute_face_phi(dudx,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,*neumann)
   d2udy2=dphidy(dudy_w,dudy_s)
   d2udx2=dphidx(dudx_w,dudx_s)
   nu_t=vis3d-viscos
   su3d=su3d+2.*viscos*nu_t*(d2udx2**2+d2udy2**2)*vol

   su3d,sp3d=modify_eps(su3d,sp3d)
   return su3d,sp3d,fmu3d


def calcw(su3d,sp3d,dpdz_old,p3d_face_l):
   """Add the z pressure-gradient source to the w-momentum equation.

   The pressure gradient is blended in time: weight ``acrank`` on the value
   computed from the current low-face pressures and ``1-acrank`` on
   ``dpdz_old``. It enters the source as ``-dp/dz*vol``. The case hook
   ``modify_w`` is then applied. The unsteady term is added later in
   ``crank_nicol``.

   Returns the updated ``(su3d, sp3d)``.
   """
   if iter == 0 and itstep == 0:
      print('calcw called')

   dpdz_new=dphidz(p3d_face_l)
   dpdz=acrank*dpdz_new+(1-acrank)*dpdz_old
   su3d=su3d-dpdz*vol

   su3d,sp3d=modify_w(su3d,sp3d)
   return su3d,sp3d

def calcp(convw,convs,convl):
   """Assemble the coefficients of the pressure equation.

   Face coefficients:
   - west/south: ``area**2/volume``, where the face volume is the average
     of the two cells sharing the face;
   - low (z): ``areaz/dz``.

   Boundary coefficients are zeroed, except on cyclic boundaries
   (cyclic_x, cyclic_z), where the periodic neighbour is used.
   ap3d[0,0,0] is overwritten with 1e5 to make the system non-singular.

   convw, convs and convl are not used in the body.

   Returns (aw3d, ae3d, as3d, an3d, al3d, ah3d, su3d, ap3d). su3d is never
   assigned here, so the module-level su3d is returned unchanged.
   """
   if itstep == 0 and iter == 0:
      print('calcp called')
# b.c., sources, coefficients
# face volumes: index 0 keeps the 1e-10 placeholder; index i>=1 is the mean of cells i-1 and i
   volw=np.ones((ni+1,nj,nk))*1e-10
   vols=np.ones((ni,nj+1,nk))*1e-10
   volw[1:,:,:]=0.5*np.roll(vol,-1,axis=0)+0.5*vol
   aw3d=areaw[0:-1,:,:]**2/volw[0:-1,:,:]
   vols[:,1:,:]=0.5*np.roll(vol,-1,axis=1)+0.5*vol
   as3d=areas[:,0:-1,:]**2/vols[:,0:-1,:]
#  aw3d=areaw[0:-1,:,:]/np.roll(deltae,1,axis=0)
#  as3d=areas[:,0:-1,:]/np.roll(deltan,1,axis=1)
   al3d=areaz[:,:,0:-1]/dz

# east/north/high coefficients are the west/south/low ones of the next cell.
# They are rolled BEFORE the boundary values below are overwritten, so the
# wrapped-around entries are fixed explicitly afterwards.
   ae3d=np.roll(aw3d,-1,axis=0)
   an3d=np.roll(as3d,-1,axis=1)
   ah3d=np.roll(al3d,-1,axis=2)


   if cyclic_x:
# periodic in x: the first west face couples cell 0 with cell ni-1
      aw3d[0,:,:]=areaw[0,:,:]**2/(0.5*(vol[0,:,:]+vol[-1,:,:]))
      ae3d[-1,:,:]=aw3d[0,:,:]
   else:
      aw3d[0,:,:]=0
      ae3d[-1,:,:]=0
   
   if not cyclic_z:
      al3d[:,:,0]=0
      ah3d[:,:,-1]=0

# south and north boundaries are never cyclic here
   as3d[:,0,:]=0
   an3d[:,-1,:]=0



   ap3d=aw3d+ae3d+as3d+an3d+al3d+ah3d

# make the system non-singular: the diagonal of cell [0,0,0] is set to 1e5
#  as3d[2,2,2]=0
#  an3d[2,2,2]=0
#  aw3d[2,2,2]=0
#  ae3d[2,2,2]=0
#  al3d[2,2,2]=0
#  ah3d[2,2,2]=0
   ap3d[0,0,0]=1e5
#  su3d[2,2,2]=0

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d


def correct_conv(u3d,v3d,w3d,p3d,aw3d_p,as3d_p,al3d_p):
   """Correct the face convective fluxes with the pressure solution p3d.

   For an interior face between cells i-1 and i (in each direction) the flux
   is incremented by ``a_p[i]*(p[i-1]-p[i])*dt[itstep]*acrank``.

   convw, convs and convl are NOT parameters. The module-level arrays are
   modified in place via slice assignment and also returned.

   u3d, v3d, w3d and p3d are returned unchanged.

   Returns (convw, convs, convl, p3d, u3d, v3d, w3d, su3d). su3d is the
   continuity error (net inflow) of each cell computed from the corrected
   fluxes.
   """
# correct convections
# create ghost cells at east & west boundaries with Neumann b.c.
   p3d_w=p3d
   p3d_s=p3d
   p3d_l=p3d
   dtt=dt[itstep]*acrank
#\\\\\\\\\\\\\ west face
# set zeros and put if before row 0
# (after the insert, p3d_w[i] == p3d[i-1] for i>=1, i.e. the upstream cell of face i)
   p3d_w=np.insert(p3d_w,0,np.zeros((nj,nk)),axis=0)
   if cyclic_x:
# interior faces 1..ni-1
      convw[1:-1,:,:]=convw[1:-1,:,:]+aw3d_p[1:,:,:]*(p3d_w[1:-1,:,:]-p3d[1:,:,:])*dtt
#     convw[-1,:,:]=convw[-1,:,:]+aw3d_p[-1,:,:]*(p3d[-1,:,:]-p3d[0,:,:])*dtt
#     convw[0,:,:]=convw[-1,:,:]
# periodic face: upstream cell is ni-1; the last face is a copy of the first
      convw[0,:,:]=convw[0,:,:]+aw3d_p[0,:,:]*(p3d[-1,:,:]-p3d[0,:,:])*dtt
      convw[-1,:,:]=convw[0,:,:]
   else:
# face 0 sees the zero ghost pressure; the last (east) face convw[-1] is not corrected
      convw[0:-1,:,:]=convw[0:-1,:,:]+aw3d_p*(p3d_w[0:-1,:,:]-p3d)*dtt


#\\\\\\\\\\\\\ south face
# set zeros and put it before column 0
# (the north-most face convs[:,-1,:] is not corrected)
   p3d_s=np.insert(p3d_s,0,np.zeros((ni,nk)),axis=1)
   convs[:,0:-1,:]=convs[:,0:-1,:]+as3d_p*(p3d_s[:,0:-1,:]-p3d)*dtt

#\\\\\\\\\\\\\ low face
# set zeros and put it before column 0
   p3d_l=np.insert(p3d_l,0,np.zeros((ni,nj)),axis=2)
   if cyclic_z:
      convl[:,:,1:-1]=convl[:,:,1:-1]+al3d_p[:,:,1:]*(p3d_l[:,:,1:-1]-p3d[:,:,1:])*dtt
# NOTE(review): unlike the x branch (which uses aw3d_p[0]), the periodic z face uses
# al3d_p[:,:,-1]; identical only if the low-face coefficient does not vary with k.
      convl[:,:,-1]=convl[:,:,-1]+al3d_p[:,:,-1]*(p3d[:,:,-1]-p3d[:,:,0])*dtt
      convl[:,:,0]=convl[:,:,-1]
   else:
      convl[:,:,0:-1]=convl[:,:,0:-1]+al3d_p*(p3d_l[:,:,0:-1]-p3d)*dtt

# continuity error
# (flux through the west/south/low face minus flux through the opposite face)
   su3d=convw[0:-1,:,:]-np.roll(convw[0:-1,:,:],-1,axis=0)+convs[:,0:-1,:]-np.roll(convs[:,0:-1,:],-1,axis=1)\
   +convl[:,:,0:-1]-np.roll(convl[:,:,0:-1],-1,axis=2)


   return convw,convs,convl,p3d,u3d,v3d,w3d,su3d


def update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l):
    """Store the current fields as the old time level for the next time step.

    The returned ``*_old`` arrays are the very objects passed in (no copy).
    The old pressure gradients are evaluated from the given face pressures.

    Returns (u_old, v_old, w_old, k_old, eps_old, om_old,
    dpdx_old, dpdy_old, dpdz_old).
    """
    gradients=(dphidx(p3d_face_w,p3d_face_s),
               dphidy(p3d_face_w,p3d_face_s),
               dphidz(p3d_face_l))
    return (u3d,v3d,w3d,k3d,eps3d,om3d)+gradients

def time_average(u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,\
                fk3d_mean,vis3d_mean):
    """Add the current instantaneous fields to the running time sums.

    The instantaneous fields (u3d, v3d, w3d, p3d, k3d, fk3d, om3d, eps3d,
    vis3d) are read from module-level globals. The sums are not divided
    here. The number of samples is counted in the global
    itstep_average_counter, which is incremented once per call.

    Returns the updated sums in the same order as the arguments.
    """

    global itstep_average_counter

    itstep_average_counter=itstep_average_counter+1
    u3d_mean=u3d_mean+u3d
    v3d_mean=v3d_mean+v3d
    w3d_mean=w3d_mean+w3d
    # accumulate pressure (previously w3d was added here by mistake)
    p3d_mean=p3d_mean+p3d
    k3d_mean=k3d_mean+k3d
    fk3d_mean=fk3d_mean+fk3d
    om3d_mean=om3d_mean+om3d
    eps3d_mean=eps3d_mean+eps3d
    vis3d_mean=vis3d_mean+vis3d
    # second moments (not central: the mean products are subtracted later, presumably in post-processing)
    uu3d_stress=uu3d_stress+u3d**2
    vv3d_stress=vv3d_stress+v3d**2
    ww3d_stress=ww3d_stress+w3d**2
    uv3d_stress=uv3d_stress+u3d*v3d
 
    return u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,\
           fk3d_mean,vis3d_mean

def crank_nicol(phi3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_local):
    """Apply a theta-scheme (Crank-Nicolson-type) time discretization.

    With theta = acrank_conv_local and A = sum of the neighbour coefficients:

        ap = apo + theta*A - sp
        su = su + (apo - (1-theta)*A)*phi_old
                + (1-theta)*sum_nb(a_nb*phi_old_nb)

    Neighbour values are taken with np.roll, i.e. with periodic
    wrap-around.

    Returns (ap3d, su3d).
    """
    explicit=1-acrank_conv_local
    a_sum=aw3d+ae3d+as3d+an3d+al3d+ah3d

    su_new=su3d+(apo3d-explicit*a_sum)*phi3d_old
    ap_new=apo3d+acrank_conv_local*a_sum-sp3d

    neighbours=(ae3d*np.roll(phi3d_old,-1,axis=0)+aw3d*np.roll(phi3d_old,1,axis=0)
                +an3d*np.roll(phi3d_old,-1,axis=1)+as3d*np.roll(phi3d_old,1,axis=1)
                +ah3d*np.roll(phi3d_old,-1,axis=2)+al3d*np.roll(phi3d_old,1,axis=2))
    su_new=su_new+explicit*neighbours
    return ap_new,su_new

def vist_kom(vis3d,k3d,om3d):
   """Return the under-relaxed total viscosity for the k-omega model.

   The new value ``k/omega + viscos`` is blended with the incoming vis3d
   using the factor ``urfvis``.
   """
   if itstep == 0 and iter == 0:
      print('vist_kom called')

   vis_new=k3d/om3d+viscos
   return urfvis*vis_new+(1.-urfvis)*vis3d

def vist_pans(vis3d,k3d,eps3d,fmu3d):
   """Return the under-relaxed total viscosity for the k-eps/PANS model.

   The new value ``cmu*fmu*k**2/eps + viscos`` is blended with the incoming
   vis3d using the factor ``urfvis``.
   """
   if itstep == 0 and iter == 0:
      print('vist_pans called')

   vis_new=cmu*fmu3d*k3d**2/eps3d+viscos
   return urfvis*vis_new+(1.-urfvis)*vis3d

def vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d):
   """Return the under-relaxed Smagorinsky-type total viscosity.

   The length scale is the minimum of:
   - a RANS scale ``0.41*min(yp2d, yp2d[1,-1]-yp2d)``, copied over the nk
     z-planes;
   - an LES scale ``cmu*vol**(1/3)``.

   The new viscosity ``rl**2*sqrt(gen) + viscos`` (``gen = 2 S_ij S_ij``)
   is blended with the incoming vis3d using ``urfvis``.
   """
   if itstep == 0 and iter == 0:
      print('vist_smag called')

   dudx=dphidx(u3d_face_w,u3d_face_s)
   dudy=dphidy(u3d_face_w,u3d_face_s)
   dudz=dphidz(u3d_face_l)
   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dvdz=dphidz(v3d_face_l)
   dwdx=dphidx(w3d_face_w,w3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)
   dwdz=dphidz(w3d_face_l)

   gen= (2.*(dudx**2+dvdy**2+dwdz**2)+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2)

   # distance-based RANS length scale, stacked along z
   wall_dist=np.minimum(yp2d,yp2d[1,-1]-yp2d)
   rl_rans_3d=np.dstack((0.41*wall_dist,)*nk)
   rl=np.minimum(rl_rans_3d,cmu*vol**0.3333333)

   vis_new=rl**2*gen**0.5+viscos
   return urfvis*vis_new+(1.-urfvis)*vis3d

def save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress, \
                        vv3d_stress,ww3d_stress,uv3d_stress):
   """Save the z-average (axis 2) of each time-accumulated field as .npy files.

   The arguments are running sums (see time_average), not yet divided by
   the number of samples. That count is stored in 'itstep.npy' together
   with nk and dz, presumably so post-processing can normalize.

   Each file is written once. The original wrote 'k_averaged.npy' twice.
   'k3d_averaged.npy' (same data) is kept because readers may depend on it.
   """
   outputs=(
      ('u_averaged',u3d_mean),
      ('v_averaged',v3d_mean),
      ('w_averaged',w3d_mean),
      ('p_averaged',p3d_mean),
      ('k_averaged',k3d_mean),
      ('fk_averaged',fk3d_mean),
      ('om_averaged',om3d_mean),
      ('vis_averaged',vis3d_mean),
      ('eps_averaged',eps3d_mean),
      ('k3d_averaged',k3d_mean),
      ('uu_stress',uu3d_stress),
      ('vv_stress',vv3d_stress),
      ('ww_stress',ww3d_stress),
      ('uv_stress',uv3d_stress),
   )
# save time-averaged data to disk
   for fname,field in outputs:
      np.save(fname, np.mean(field,axis=2))
   np.save('itstep',[itstep_average_counter,nk,dz])
   print('itstep_average_counter,nk,dz',itstep_average_counter,nk,dz)
 
   return



def read_restart_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d):
   """Load restart fields (written by save_data) from the working directory.

   - Velocities and pressure are always loaded.
   - k/eps are loaded when keps or pans is set.
   - k/omega are loaded when kom or kom_des is set.

   Fields that are not loaded are returned as passed in.
   """
   velocity_files=('u3d_saved.npy','v3d_saved.npy','w3d_saved.npy','p3d_saved.npy')
   u3d,v3d,w3d,p3d=[np.load(fname) for fname in velocity_files]

   if keps or pans:
      k3d=np.load('k3d_saved.npy')
      eps3d=np.load('eps3d_saved.npy')
   if kom or kom_des:
      k3d=np.load('k3d_saved.npy')
      om3d=np.load('om3d_saved.npy')

   return u3d,v3d,w3d,p3d,k3d,eps3d,om3d

def save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d):
   """Write the restart fields to .npy files in the working directory.

   - Velocities and pressure are always saved.
   - k/eps are saved when keps or pans is set.
   - k/omega are saved when kom or kom_des is set.
   """
   for fname,field in (('u3d_saved',u3d),('v3d_saved',v3d),('w3d_saved',w3d),('p3d_saved',p3d)):
      np.save(fname, field)

   if keps or pans:
      np.save('k3d_saved', k3d)
      np.save('eps3d_saved', eps3d)
   if kom or kom_des:
      np.save('k3d_saved', k3d)
      np.save('om3d_saved', om3d)

   return 

def vist_wale(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d):
   """Return the under-relaxed WALE subgrid-scale total viscosity.

   WALE model (Nicoud & Ducros, 1999):

      vis = Cw**2*delta**2*(Sd_ij Sd_ij)**1.5
            /((S_ij S_ij)**2.5+(Sd_ij Sd_ij)**1.25) + viscos

   where:
   - delta = vol**(1/3);
   - S_ij is the strain-rate tensor;
   - Sd_ij is the traceless symmetric part of g_ik g_kj, with
     g_ij = du_i/dx_j;
   - Cw**2 = 10.6*Cs**2 with Cs = 0.1.

   cm already holds Cw**2, so it multiplies delta**2 directly. The original
   used (cm*delta)**2, which squared the constant twice. Where the velocity
   gradient vanishes the operator is 0/0 and is taken as 0 (instead of NaN).

   The result is blended with the incoming vis3d using urfvis.
   """
   if itstep == 0 and iter == 0:
      print('vist_wale called')

   dudx=dphidx(u3d_face_w,u3d_face_s)
   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dwdx=dphidx(w3d_face_w,w3d_face_s)

   dudy=dphidy(u3d_face_w,u3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)

   dudz=dphidz(u3d_face_l)
   dvdz=dphidz(v3d_face_l)
   dwdz=dphidz(w3d_face_l)

# velocity-gradient tensor g[i][j] = du_i/dx_j
   g=((dudx,dudy,dudz),
      (dvdx,dvdy,dvdz),
      (dwdx,dwdy,dwdz))

# square of g_ij = g_ik g_kj
   g2=[[g[i][0]*g[0][j]+g[i][1]*g[1][j]+g[i][2]*g[2][j] for j in range(3)] for i in range(3)]
   gkk_2=(g2[0][0]+g2[1][1]+g2[2][2])/3.

# S_ij S_ij and Sd_ij Sd_ij
   sijsij=0.
   sdijsdij=0.
   for i in range(3):
      for j in range(3):
         s_ij=0.5*(g[i][j]+g[j][i])
         sd_ij=0.5*(g2[i][j]+g2[j][i])
         if i == j:
            sd_ij=sd_ij-gkk_2
         sijsij=sijsij+s_ij*s_ij
         sdijsdij=sdijsdij+sd_ij*sd_ij

# Cw**2 = 10.6*Cs**2, Cs = 0.1
   cm=10.6*0.1**2

   numerator=sdijsdij**1.5
   denominator=sijsij**2.5+sdijsdij**1.25
# both vanish together (zero gradient): define the operator as 0 there
   term1=np.divide(numerator,denominator,out=np.zeros_like(numerator),where=denominator>0)

   visold= vis3d
   delta=vol**0.333333
   vis3d= cm*delta**2*term1+viscos
#            under-relax viscosity
   vis3d= urfvis*vis3d+(1.-urfvis)*visold
   return vis3d

# set up the case (defined elsewhere; presumably grid, parameters and b.c.)
init()

# initialization
itstep_average_counter=0 # number of samples accumulated by time_average
# cell-centred fields, shape (ni,nj,nk). The 1e-20 fill presumably avoids exact
# zeros. NOTE(review): confirm intent.
u3d=np.ones((ni,nj,nk))*1e-20
v3d=np.ones((ni,nj,nk))*1e-20
w3d=np.ones((ni,nj,nk))*1e-20
p3d=np.ones((ni,nj,nk))*1e-20
k3d=np.ones((ni,nj,nk))*1
eps3d=np.ones((ni,nj,nk))*1
om3d=np.ones((ni,nj,nk))*1
vis3d=np.ones((ni,nj,nk))*viscos

fk3d=np.ones((ni,nj,nk))

# pressure gradients at the previous time level
dpdx_old=np.ones((ni,nj,nk))*1e-20
dpdy_old=np.ones((ni,nj,nk))*1e-20
dpdz_old=np.ones((ni,nj,nk))*1e-20

# face convective fluxes: one extra entry along the face-normal direction
convw=np.ones((ni+1,nj,nk))*1e-20
convs=np.ones((ni,nj+1,nk))*1e-20
convl=np.ones((ni,nj,nk+1))*1e-20

# running sums for time averaging (see time_average)
u3d_mean=np.ones((ni,nj,nk))*1e-20
v3d_mean=np.ones((ni,nj,nk))*1e-20
w3d_mean=np.ones((ni,nj,nk))*1e-20
p3d_mean=np.ones((ni,nj,nk))*1e-20
k3d_mean=np.ones((ni,nj,nk))*1e-20
om3d_mean=np.ones((ni,nj,nk))*1e-20
eps3d_mean=np.ones((ni,nj,nk))*1e-20
uu3d_stress=np.ones((ni,nj,nk))*1e-20
vv3d_stress=np.ones((ni,nj,nk))*1e-20
ww3d_stress=np.ones((ni,nj,nk))*1e-20
uv3d_stress=np.ones((ni,nj,nk))*1e-20
fk3d_mean=np.ones((ni,nj,nk))*1e-20
vis3d_mean=np.ones((ni,nj,nk))*1e-20

# discretization coefficients and sources (overwritten each iteration)
aw3d=np.ones((ni,nj,nk))*1e-20
ae3d=np.ones((ni,nj,nk))*1e-20
as3d=np.ones((ni,nj,nk))*1e-20
an3d=np.ones((ni,nj,nk))*1e-20
al3d=np.ones((ni,nj,nk))*1e-20
ah3d=np.ones((ni,nj,nk))*1e-20
ap3d=np.ones((ni,nj,nk))*1e-20
apo3d=np.ones((ni,nj,nk))*1e-20
su3d=np.ones((ni,nj,nk))*1e-20
sp3d=np.ones((ni,nj,nk))*1e-20
# velocity gradients later returned by calck_ls and passed to calceps_ls
dudx=np.ones((ni,nj,nk))*1e-20
dudy=np.ones((ni,nj,nk))*1e-20
# inlet synthetic fluctuations, shape (nj,nk)
usynt_inlet=np.ones((nj,nk))*1e-20
vsynt_inlet=np.ones((nj,nk))*1e-20
wsynt_inlet=np.ones((nj,nk))*1e-20

# compute Delta_max (largest cell dimension) for LES/DES/PANS models
delta_max=np.maximum(deltae,deltan)
delta_max=np.maximum(delta_max,dz)

# initialize (case-specific hook)
u3d,v3d,w3d,k3d,om3d,eps3d,vis3d =modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d)


# read data for restart
if restart: 
   u3d,v3d,w3d,p3d,k3d,eps3d,om3d= read_restart_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)

# face values of the initial fields
u3d_face_w,u3d_face_s,u3d_face_l=compute_face_phi(u3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,\
    u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_z_type)
v3d_face_w,v3d_face_s,v3d_face_l=compute_face_phi(v3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_z,\
    v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_z_type)
w3d_face_w,w3d_face_s,w3d_face_l=compute_face_phi(w3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_z,\
    w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_z_type)
p3d_face_w,p3d_face_s,p3d_face_l=compute_face_phi(p3d,p_bc_west,p_bc_east,p_bc_south,p_bc_north,p_bc_z,\
    p_bc_west_type,p_bc_east_type,p_bc_south_type,p_bc_north_type,p_bc_z_type)


# initial "old" time level
u3d_old,v3d_old,w3d_old,k3d_old,eps3d_old,om3d_old,dpdx_old,dpdy_old,dpdz_old=update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l)

epsmin=np.min(eps3d.flatten())
kmin=np.min(k3d.flatten())

itstep=0
iter=0


print('kmin,epsmin',kmin,epsmin)

# initial turbulent viscosity, computed without under-relaxation (urfvis temporarily 1)
if kom or kom_des:
   urf_temp=urfvis # no under-relaxation
   urfvis=1
   vis3d=vist_kom(vis3d,k3d,om3d)
   urfvis=urf_temp

if pans or keps:
   if pans:
     fk3d=compute_fk(k3d,eps3d)
# compute fmu3d (calceps_ls is called only for its fmu3d output here)
   gen=np.zeros((ni,nj,nk))
   itstep=1
   itstep=0
   su3d,sp3d,fmu3d= calceps_ls(su3d,sp3d,k3d,eps3d,vis3d,gen,dudx,dudy)
   urf_temp=urfvis # no under-relaxation
   urfvis=1
   vis3d=vist_pans(vis3d,k3d,eps3d,fmu3d)
   urfvis=urf_temp

if smag:
   urf_temp=urfvis # no under-relaxation
   urfvis=1
   vis3d=vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d)
   urfvis=urf_temp

itstep=0
# initial face convective fluxes
convw,convs,convl=conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l)

iter=0
itstep=0

# find max index
#sumax=np.max(su3d.flatten())
#print('[i,j,k]', np.where(su3d == np.amax(su3d)) 

# residuals start at zero; each is overwritten (and rescaled) only when its
# equation is solved inside the time loop
residual_u = residual_v = residual_w = residual_p = 0
residual_k = residual_eps = residual_om = 0


# main time loop: at each time step, iterate u, v, w, p (and turbulence) until
# the largest scaled residual drops below sormax or maxit iterations are done
for itstep in range(0,ntstep):
   for iter in range(0,maxit):

      start_time_iter = time.time()
# coefficients for velocities
      start_time = time.time()
# compute inlet fluctuations (updates the west velocity b.c. and inlet fluxes)
# NOTE(review): the compute_inlet_fluct shown at the top of this file takes no
# arguments; confirm which definition is active.
      if iter == 0:
         u_bc_west,v_bc_west,w_bc_west,u3d_face_w,convw = compute_inlet_fluct(u_bc_west,v_bc_west,w_bc_west,u3d_face_w,convw)
      aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,1,scheme)
# u3d
# boundary conditions for u3d
      su3d,sp3d=bc(su3d,sp3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z, \
                   u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_z_type)
      su3d,sp3d=calcu(su3d,sp3d,dpdx_old,p3d_face_w,p3d_face_s)
      ap3d,su3d=crank_nicol(u3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)

      u3d,residual_u=solve_3d(u3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_vel,nsweep_vel,acrank_conv)
      print('time u',time.time()-start_time)

# v3d
# boundary conditions for v3d
      su3d,sp3d=bc(su3d,sp3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_z, \
                   v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_z_type)
      su3d,sp3d=calcv(su3d,sp3d,dpdy_old,p3d_face_w,p3d_face_s)
      ap3d,su3d=crank_nicol(v3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)
      v3d,residual_v=solve_3d(v3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_vel,nsweep_vel,acrank_conv)

# w3d
# boundary conditions for w3d
      su3d,sp3d=bc(su3d,sp3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_z, \
                   w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_z_type)
      su3d,sp3d=calcw(su3d,sp3d,dpdz_old,p3d_face_l)

      
      ap3d,su3d=crank_nicol(w3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)
      w3d,residual_w=solve_3d(w3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_vel,nsweep_vel,acrank_conv)

# p3d
      convw,convs,convl=conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l)
      convw,u_bc_east =bc_outlet(convw)


# RHS
# continuity error
      su3d=(convw[0:-1,:,:]-np.roll(convw[0:-1,:,:],-1,axis=0)\
        +convs[:,0:-1,:]-np.roll(convs[:,0:-1,:],-1,axis=1)\
        +convl[:,:,0:-1]-np.roll(convl[:,:,0:-1],-1,axis=2))/acrank/dt[itstep]


# pressure-equation coefficients are computed once (first iteration of the first time step)
      if iter == 0 and itstep == 0:
         aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d=calcp(convw,convs,convl)
         aw3d_p=aw3d
         as3d_p=as3d
         al3d_p=al3d


      p3d=solve_p(p3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_p)

# correct the face fluxes with the new pressure
      convw,convs,convl,p3d,u3d,v3d,w3d,su3d= correct_conv(u3d,v3d,w3d,p3d,aw3d_p,as3d_p,al3d_p)
# L1 norm of the remaining continuity error (ravel instead of the deprecated np.matrix API)
      res_1d=su3d.ravel()
      residual_p=np.linalg.norm(res_1d,ord=1)

      u3d_face_w,u3d_face_s,u3d_face_l=compute_face_phi(u3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,\
        u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_z_type)
      v3d_face_w,v3d_face_s,v3d_face_l=compute_face_phi(v3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_z,\
        v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_z_type)
      w3d_face_w,w3d_face_s,w3d_face_l=compute_face_phi(w3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_z,\
        w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_z_type)
      p3d_face_w,p3d_face_s,p3d_face_l=compute_face_phi(p3d,p_bc_west,p_bc_east,p_bc_south,p_bc_north,p_bc_z,\
        p_bc_west_type,p_bc_east_type,p_bc_south_type,p_bc_north_type,p_bc_z_type)

      if kom or kom_des:
         vis3d=vist_kom(vis3d,k3d,om3d)
# coefficients
         start_time = time.time()
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k,scheme_keps)
# k
# boundary conditions for k3d
         su3d,sp3d=bc(su3d,sp3d,k_bc_west,k_bc_east,k_bc_south,k_bc_north,k_bc_z, \
                   k_bc_west_type,k_bc_east_type,k_bc_south_type,k_bc_north_type,k_bc_z_type)
         su3d,sp3d,gen,fk3d=calck_kom(su3d,sp3d,k3d,om3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l)

         ap3d,su3d=crank_nicol(k3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_kom)

         k3d,residual_k=solve_3d(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,nsweep_kom,acrank_conv_kom)
         k3d=np.maximum(k3d,1e-10)
         print('time k',time.time()-start_time)


# omega
# boundary conditions for om3d
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k,scheme_keps)
         su3d,sp3d=bc(su3d,sp3d,om_bc_west,om_bc_east,om_bc_south,om_bc_north,om_bc_z, \
                   om_bc_west_type,om_bc_east_type,om_bc_south_type,om_bc_north_type,om_bc_z_type)
         su3d,sp3d= calcom(su3d,sp3d,om3d,gen)
         ap3d,su3d=crank_nicol(om3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_kom)

         om3d,residual_om=solve_3d(om3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_om,nsweep_kom,acrank_conv_kom)
         om3d=np.maximum(om3d,1e-10)


      if smag:
         vis3d=vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d)
      if pans or keps:
         vis3d=vist_pans(vis3d,k3d,eps3d,fmu3d)
# coefficients
         start_time = time.time()
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k,scheme_keps)
# k
# boundary conditions for k3d
         su3d,sp3d=bc(su3d,sp3d,k_bc_west,k_bc_east,k_bc_south,k_bc_north,k_bc_z, \
                   k_bc_west_type,k_bc_east_type,k_bc_south_type,k_bc_north_type,k_bc_z_type)
         su3d,sp3d,gen,dudx,dudy=calck_ls(su3d,sp3d,k3d,eps3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l)

         ap3d,su3d=crank_nicol(k3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_keps)

         k3d,residual_k=solve_3d(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,nsweep_keps,acrank_conv_keps)
         k3d=np.maximum(k3d,1e-6)
         print('time k',time.time()-start_time)

# eps
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_eps,scheme_keps)
# boundary conditions for eps3d
         su3d,sp3d=bc(su3d,sp3d,eps_bc_west,eps_bc_east,eps_bc_south,eps_bc_north,eps_bc_z, \
                   eps_bc_west_type,eps_bc_east_type,eps_bc_south_type,eps_bc_north_type,eps_bc_z_type)
         su3d,sp3d,fmu3d= calceps_ls(su3d,sp3d,k3d,eps3d,vis3d,gen,dudx,dudy)

         ap3d,su3d=crank_nicol(eps3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_keps)

         eps3d,residual_eps=solve_3d(eps3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_eps,nsweep_keps,acrank_conv_keps)
         eps3d=np.maximum(eps3d,1e-6)

         if pans:
            fk3d=compute_fk(k3d,eps3d)

# scale residuals
      residual_u=residual_u/resnorm_vel
      residual_v=residual_v/resnorm_vel
      residual_w=residual_w/resnorm_vel
      residual_p=residual_p/resnorm_p
      residual_k=residual_k/resnorm_vel**2
      residual_eps=residual_eps/resnorm_vel**3
      residual_om=residual_om/resnorm_vel

# the convergence check must include every solved equation, w as well
      resmax=np.max([residual_u,residual_v,residual_w,residual_p,residual_k,residual_eps,residual_om])

      print('-time step: %d, iter: %d, max residul=%10.2E, u=%10.2E, v=%10.2E,\
w=%10.2E, cont=%10.2E, k=%10.2E, eps=%10.2E, om =%10.2E\n\n'\
      % (itstep,iter, resmax,residual_u, residual_v, residual_w, residual_p, residual_k, residual_eps, residual_om))

      print('monitor --- -time step: %d, iter: %d, u=%10.2E, v=%10.2E, w=%10.2E, p=%10.2E, \
k=%10.2E, eps=%10.2E, om=%10.2E, vis=%10.2E,\n\n'\
      % (itstep,iter,u3d[imon,jmon,kmon],v3d[imon,jmon,kmon],w3d[imon,jmon,kmon],p3d[imon,jmon,kmon],\
          k3d[imon,jmon,kmon],eps3d[imon,jmon,kmon],om3d[imon,jmon,kmon],vis3d[imon,jmon,kmon]))


      vismax=np.max(vis3d.flatten())/viscos
      umax=np.max(u3d.flatten())
      epsmin=np.min(eps3d.flatten())
      ommin=np.min(om3d.flatten())

      kmin=np.min(k3d.flatten())
      dx=x2d[1,0]-x2d[0,0]
      cfl=umax*dt[itstep]/dx

      print('-time step: %d, dt: %8.2E, iter: %d, umax=%8.2E, cfl = %8.2E, vismax = %8.2E, kmin = %8.2E, epsmin  %8.2E, ommin  %8.2E\n\n'\
      % (itstep,dt[itstep],iter, umax,cfl,vismax,kmin,epsmin,ommin))
 
      if iter > 0 and resmax < sormax:  
#     if resmax < sormax:  

         break


# store the converged fields as the old time level
   u3d_old,v3d_old,w3d_old,k3d_old,eps3d_old,om3d_old,dpdx_old,dpdy_old,dpdz_old=\
     update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l)
# save data every itstep_save time steps
   if itstep%itstep_save == 0:
      save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress,\
                          vv3d_stress,ww3d_stress,uv3d_stress)
   if save and itstep%itstep_save == 0:
      save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)

# accumulate time-averaging sums every itstep_average steps after itstep_start
   if itstep >= itstep_start and itstep % itstep_average == 0:
      u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,\
          fk3d_mean,vis3d_mean= time_average(u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,\
          vv3d_stress,ww3d_stress,uv3d_stress,fk3d_mean,vis3d_mean)
   print('time one iteration',time.time()-start_time_iter)
      
# save data for restart (written only when the save flag is set)
if save:
   save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)

# save time-averaged data (always written at the end of the run)
save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress,\
                    vv3d_stress,ww3d_stress,uv3d_stress)

print('program reached normal stop')

